import numpy as np
import torch
import cv2

import ipdb

from utils.transporter_utils import get_image_transform
st = ipdb.set_trace  # debugging shortcut: calling st() drops into an ipdb breakpoint

def _clamp_boxes(frame_boxes, H, W):
    frame_boxes = np.clip(frame_boxes, a_min=0, a_max=None)
    frame_boxes = np.minimum(
        frame_boxes, np.array([W-1, H-1, W-1, H-1]).reshape(1, 4))
    return frame_boxes

def clamp_boxes_torch(boxes, H, W):
    """Torch counterpart of ``_clamp_boxes``: clip boxes into an H x W image.

    Coordinates are floored at 0. Then x values (columns 0 and 2) are capped
    at ``W - 1`` and y values (columns 1 and 3) at ``H - 1``. The bound tensor
    is moved to ``boxes.device`` before the comparison.

    Returns:
        A new tensor; ``boxes`` itself is not modified.
    """
    lower_clamped = boxes.clamp(min=0)
    upper_bound = torch.tensor([W - 1, H - 1] * 2).reshape(1, 4)
    return torch.minimum(lower_clamped, upper_bound.to(boxes.device))

def _load_gt_boxes(detections, modes, H, W, flip=False):
    """Collect ground-truth boxes for the requested modes as integer lists.

    For every mode in ``modes``, each value ``k`` of ``detections[mode]`` is
    read as two corner points, ``k[0][0]`` and ``k[0][1]``. Index 1 of each
    point is used as x and index 0 as y. The resulting ``(x1, y1, x2, y2)``
    box is clamped into the H x W image with ``_clamp_boxes`` and truncated
    to ``int64``.

    Args:
        detections: mapping from mode to a dict of detections. The per-value
            layout is assumed to be as described above; TODO confirm against
            the producer.
        modes: iterable of keys into ``detections``.
        H: image height used for clamping.
        W: image width used for clamping.
        flip: if True, emit ``[y1, x1, y2, x2]`` instead of
            ``[x1, y1, x2, y2]``.

    Returns:
        List of 4-element lists of Python ints, in iteration order.
    """
    boxes = []  # renamed from ``all`` to avoid shadowing the builtin
    for mode in modes:
        for k in detections[mode].values():
            start, end = k[0][0], k[0][1]
            # Points are indexed [y, x]; reorder to (x1, y1, x2, y2).
            raw = np.array([start[1], start[0], end[1], end[0]]).reshape(1, -1)
            x1, y1, x2, y2 = (
                _clamp_boxes(raw, H, W).astype(np.int64).reshape(-1).tolist()
            )
            # The values are already Python ints; no second array round-trip needed.
            boxes.append([y1, x1, y2, x2] if flip else [x1, y1, x2, y2])
    return boxes


def flip_data(img, ground_truths):
    """Swap the two leading axes of ``img`` and x/y of every box.

    ``img`` is transposed with axes ``(1, 0, 2)``, so it must have exactly
    three dimensions (e.g. a NumPy array). Each box ``[x1, y1, x2, y2]`` in
    ``ground_truths`` is rewritten as ``[y1, x1, y2, x2]``.

    Note:
        ``ground_truths`` is modified in place, element by element, and is
        also returned.

    Returns:
        ``(transposed_img, ground_truths)``.
    """
    swapped_img = img.transpose(1, 0, 2)
    for idx, (x1, y1, x2, y2) in enumerate(ground_truths):
        ground_truths[idx] = [y1, x1, y2, x2]
    return swapped_img, ground_truths


def flip_boxes(boxes):
    """Swap x and y in every ``[x1, y1, x2, y2]`` box, in place.

    Each box becomes ``[y1, x1, y2, x2]``, and the same container is returned.

    For a ``(N, 4)`` ``torch.Tensor`` the swap is a single vectorized column
    gather written back into ``boxes``. This keeps the tensor's dtype and
    device and avoids building one small tensor per row, which for CUDA
    inputs presumably also meant a host round-trip per row.

    Any other container (list, NumPy array, tensors of a different shape)
    keeps the original per-element behavior: each element is replaced by a
    new ``torch.Tensor``.
    """
    if isinstance(boxes, torch.Tensor) and boxes.dim() == 2 and boxes.shape[1] == 4:
        # Advanced indexing returns a copy, so writing it back cannot read
        # values that were already overwritten.
        boxes[:] = boxes[:, [1, 0, 3, 2]]
        return boxes
    for idx, (x1, y1, x2, y2) in enumerate(boxes):
        boxes[idx] = torch.tensor([y1, x1, y2, x2])
    return boxes


def flip_heatmaps(pick_hmap, place_hmap):
    """Swap the last two axes of the pick and place heatmaps.

    Both inputs must be 3-D tensors, since ``permute(0, 2, 1)`` needs exactly
    three dims; e.g. ``(C, A, B)`` becomes ``(C, B, A)``.

    Returns:
        ``(pick_hmap, place_hmap)`` with their last two axes exchanged.
    """
    return tuple(hmap.permute(0, 2, 1) for hmap in (pick_hmap, place_hmap))


def make_circle(boxes, center, diameter, H=320, W=640):
    """Move ``boxes`` so their top-left corners lie evenly spaced on a circle.

    Box ``i`` of ``n`` is placed at angle ``2*pi*i/n`` in a normalized frame:
    ``x = xc + sin(angle) * diameter / 2`` and
    ``y = yc + cos(angle) * diameter / 2``. These are mapped to pixels with
    ``(x + 1) * H`` and ``(y + 0.5) * H``. Because ``W == 2 * H`` is asserted,
    this maps x in [-1, 1] to [0, W] and y in [-0.5, 0.5] to [0, H]. Each box
    keeps its extent ``(x2 - x1, y2 - y1)``.

    Args:
        boxes: tensor whose rows are presumably ``(x1, y1, x2, y2)``.
            It is modified in place.
        center: ``(xc, yc)`` in the normalized frame.
        diameter: circle diameter in the normalized frame.
        H: image height in pixels.
        W: image width in pixels; must equal ``2 * H``.

    Returns:
        ``boxes`` (the same tensor).
    """
    assert W == H * 2
    xc, yc = center
    radius = diameter / 2.0
    count = len(boxes)
    angles = 2 * np.pi * np.arange(count) / count

    # Circle coordinates in the normalized frame, mapped straight into pixels.
    xs_px = (xc + np.sin(angles) * radius + 1.0) * H
    ys_px = (yc + np.cos(angles) * radius + 0.5) * H

    sizes = boxes[:, 2:] - boxes[:, :2]  # capture extents before moving corners
    boxes[:, 0] = torch.from_numpy(xs_px).to(boxes.device)
    boxes[:, 1] = torch.from_numpy(ys_px).to(boxes.device)
    boxes[:, 2:] = boxes[:, :2] + sizes
    return boxes


def make_line(boxes, center, diameter, H=320, W=640):
    """Move ``boxes`` so their top-left corners form an evenly spaced vertical line.

    All boxes get the same x, ``(center[0] + 1) * H``. Their y values are
    ``diameter`` apart in the normalized frame, centered on ``center[1]``,
    and mapped to pixels with ``(y + 0.5) * H``. Because ``W == 2 * H`` is
    asserted, this maps normalized x in [-1, 1] to [0, W] and y in
    [-0.5, 0.5] to [0, H]. Each box keeps its extent ``(x2 - x1, y2 - y1)``.

    Args:
        boxes: tensor whose rows are presumably ``(x1, y1, x2, y2)``.
            It is modified in place.
        center: ``(xc, yc)`` in the normalized frame.
        diameter: spacing between consecutive boxes in the normalized frame.
            May be an int or a float.
        H: image height in pixels.
        W: image width in pixels; must equal ``2 * H``.

    Returns:
        ``boxes`` (the same tensor).
    """
    assert W == H * 2
    N = len(boxes)
    end_points = boxes[:, 2:] - boxes[:, :2]  # extents, captured before moving corners
    boxes[:, 0] = (center[0] + 1.0) * H
    ycs = torch.arange(N) * diameter
    if not ycs.is_floating_point():
        # An int ``diameter`` yields an int64 tensor, and ``mean()`` raises for
        # integer dtypes. Promote to the default float dtype, which is what a
        # float ``diameter`` already produces.
        ycs = ycs.to(torch.get_default_dtype())
    ycs = (ycs - ycs.mean()) + center[1]
    ycs = (ycs + 0.5) * H
    boxes[:, 1] = ycs.to(boxes.device)
    boxes[:, 2:] = boxes[:, :2] + end_points
    return boxes


def mask_image_with_boxes(image, boxes, padding=3):
    """Return a copy of ``image`` that is zero everywhere except inside ``boxes``.

    Each box is ``(x1, y1, x2, y2)``. It is grown by ``padding`` pixels on
    every side and clipped to the image. The pixels ``image[y1:y2, x1:x2]``
    are then copied into the output. Because of this ``image[y, x, ...]``
    indexing, ``image`` is expected to be ``(H, W)`` or ``(H, W, ...)``.

    Args:
        image: torch tensor to mask.
        boxes: ``(B, 4)`` tensor of boxes. It is NOT modified; padding and
            clipping act on a private int64 copy. Float coordinates are
            truncated toward zero before padding.
        padding: pixels added on each side of every box (default 3).

    Returns:
        A new tensor with the same shape and dtype as ``image``.
    """
    masked_image_ = torch.zeros_like(image)
    height, width = image.shape[0], image.shape[1]

    # Work on an int64 copy for three reasons: the caller's tensor stays
    # untouched, float boxes become valid slice bounds, and unsigned dtypes
    # cannot wrap around when the padding is subtracted.
    boxes = boxes.to(torch.long, copy=True)
    boxes[:, :2] -= padding
    boxes[:, 2:] += padding
    # Slice ends are exclusive, so clamping to the full extent is correct.
    boxes[:, 0::2].clamp_(0, width)   # x1, x2
    boxes[:, 1::2].clamp_(0, height)  # y1, y2

    for x1, y1, x2, y2 in boxes.tolist():
        masked_image_[y1:y2, x1:x2] = image[y1:y2, x1:x2]

    return masked_image_
    
    